#!/usr/bin/env python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.


from collections import defaultdict
import io
import os
import sys
import argparse


import onnxruntime as rt
import onnxruntime.capi.onnxruntime_pybind11_state as rtpy 
from onnxruntime.capi.onnxruntime_pybind11_state import opkernel
from onnxruntime.capi.onnxruntime_pybind11_state import schemadef 
from onnxruntime.capi.onnxruntime_pybind11_state.opkernel import KernelDef 
from onnxruntime.capi.onnxruntime_pybind11_state.schemadef import OpSchema 


def format_version_range(v):
    """Render an opset version range as display text.

    ``v`` is indexed as ``v[0]`` (first version) and ``v[1]`` (last version).
    When ``v[1]`` is at least 2147483647 the range is rendered as
    ``'<first>+'``; otherwise it is rendered as ``'[<first>, <last>]'``.
    """
    open_ended = v[1] >= 2147483647
    return str(v[0]) + '+' if open_ended else '[' + str(v[0]) + ', ' + str(v[1]) + ']'

def format_type_constraints(tc):
    """Join type-constraint names into one comma-separated string.

    Args:
        tc: iterable of strings.

    Returns:
        The items joined with ``', '`` in iteration order, or ``''`` when
        ``tc`` is empty.
    """
    return ', '.join(tc)

def format_param_strings(params):
    """Join alternative parameter signatures with ``' or '``.

    Args:
        params: iterable of signature strings, or ``None``.

    Returns:
        The signatures joined with ``' or '`` in iteration order, or ``''``
        when ``params`` is ``None`` or empty.
    """
    if not params:
        return ''
    return ' or '.join(params)
    
def _display_domain(domain):
    """Return the domain label to print; the empty (default ONNX) domain is shown as 'ai.onnx'."""
    # 'ai.onnx.ml' is a separate domain, so the empty default domain must not
    # be folded into it.
    return domain if domain != '' else 'ai.onnx'


def _schema_signature(schema):
    """Build the '(*in* name:**type**, ..., *out* name:**type**)' text for one operator schema."""
    parts = ['*in* {}:**{}**'.format(inp.name, inp.typeStr)
             for inp in (schema.inputs or [])]
    parts.extend('*out* {}:**{}**'.format(outp.name, outp.typeStr)
                 for outp in (schema.outputs or []))
    return '(' + ', '.join(parts) + ')'


def main(args):  # type: (Type[Args]) -> None
    """Write the operator-kernel markdown tables to ``args.output``.

    Operator schemas (``rtpy.get_all_operator_schema()``) supply the parameter
    signatures; kernel definitions (``rtpy.get_all_opkernel_def()``) supply the
    provider, domain, version range and type constraints. One table is written
    per provider, with rows grouped by domain, operator name and version range.

    Args:
        args: object exposing an ``output`` attribute holding the path of the
            markdown file to (over)write.
    """
    with io.open(args.output, 'w', newline='', encoding="utf-8") as fout:
        fout.write('## Supported Operators Data Types\n')
        fout.write(
            "*This file is automatically generated from the\n"
            "            [def files](/onnxruntime/core/providers/cpu/cpu_execution_provider.cc) via [this script](/tools/python/gen_opkernel_doc.py).\n"
            "            Do not modify directly and instead edit operator definitions.*\n")

        # '<domain>.<op name>' -> set of distinct signature strings (several
        # schemas can share a full name).
        paramdict = defaultdict(set)
        for schema in rtpy.get_all_operator_schema():
            fullname = _display_domain(schema.domain) + '.' + schema.name
            paramdict[fullname].add(_schema_signature(schema))

        # provider -> domain -> op name -> list of kernel definitions.
        index = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for op in rtpy.get_all_opkernel_def():
            index[op.provider][_display_domain(op.domain)][op.op_name].append(op)

        fout.write('\n')
        for provider, domainmap in sorted(index.items()):
            fout.write('\n\n## Operators implemented by '+provider+'\n\n')
            fout.write('| Op Name | Parameters | OpSet Version | Types Supported |\n')
            fout.write('|---------|------------|---------------|-----------------|\n')
            for domain, namemap in sorted(domainmap.items()):
                fout.write('**Operator Domain:** *'+domain+'*\n')
                for name, ops in sorted(namemap.items()):
                    # Key by the numeric (first, last) tuple rather than the
                    # formatted text so ranges sort numerically ('9+' before
                    # '10+'); the text is produced only when writing.
                    version_type_index = defaultdict(lambda: defaultdict(set))
                    for op in ops:
                        version_range = tuple(op.version_range)
                        for tname, tclist in op.type_constraints.items():
                            # Entries are created only for non-empty constraint
                            # lists, so empty ones produce no table row.
                            for c in tclist:
                                version_type_index[version_range][tname].add(c)

                    # Sets have no stable iteration order; sort so the
                    # generated file is identical from run to run.
                    signature = format_param_strings(
                        sorted(paramdict.get(domain + '.' + name, ())))

                    first_row_for_op = True
                    for version_range, typemap in sorted(version_type_index.items()):
                        first_row_for_version = True
                        for tname, tcset in sorted(typemap.items()):
                            # Name and parameters appear only on the operator's first row.
                            if first_row_for_op:
                                fout.write('|' + name + '|' + signature + '|')
                                first_row_for_op = False
                            else:
                                fout.write('| | |')
                            # The version appears only on the first row of each range.
                            if first_row_for_version:
                                fout.write(format_version_range(version_range) + '|')
                                first_row_for_version = False
                            else:
                                fout.write('|')
                            fout.write('**' + tname + '** = '
                                       + format_type_constraints(sorted(tcset)) + '|\n')

                # Spacer rows between domains.
                fout.write('| |\n| |\n')
        

if __name__ == '__main__':
    # By default the markdown file is written next to this script
    # (realpath resolves the script's own location first).
    script_dir = os.path.dirname(os.path.realpath(__file__))
    default_output = os.path.join(script_dir, 'OperatorKernels.md')

    parser = argparse.ArgumentParser(description='ONNX Runtime Operator Kernel Documentation Generator')
    parser.add_argument('--output_path', help='output markdown file path', default=default_output)
    args = parser.parse_args()

    # main() reads the destination from an ``output`` attribute, so expose
    # the parsed --output_path under that name.
    class Args(object):
        output = args.output_path

    main(Args)
